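This tutorial demonstrates how to fine-tune the RecurrentGemma 2B Instruct model for an English-to-French translation task. It uses Google DeepMind's recurrentgemma library, JAX (a high-performance numerical computing library), Flax (a JAX-based neural network library), Chex (a library of utilities for writing reliable JAX code), Optax (a JAX-based gradient processing and optimization library), and the MTNT (Machine Translation of Noisy Text) dataset. Although Flax is not used directly in this notebook, Gemma was created using Flax.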
The recurrentgemma library is written with JAX, Flax, Orbax (a JAX-based library for training utilities such as checkpointing), and SentencePiece (a tokenizer/detokenizer library).
This notebook can run on Google Colab with a T4 GPU (go to Edit > Notebook settings, and under Hardware accelerator select T4 GPU).
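Setup
The following sections walk through the steps for preparing a notebook to use a RecurrentGemma model: getting model access, getting an API key, and configuring the notebook runtime.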
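Set up Kaggle access for Gemma
To complete this tutorial, first follow setup instructions similar to those in Gemma setup, with a few exceptions:
- Get access to RecurrentGemma (instead of Gemma) on kaggle.com.
- Select a Colab runtime with sufficient resources to run the RecurrentGemma model.
- Generate and configure a Kaggle username and API key.
After you've completed the RecurrentGemma setup, move on to the next section, where you'll set environment variables for your Colab environment.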
Set environment variables
Set environment variables for KAGGLE_USERNAME and KAGGLE_KEY. When you see the "Grant access?" prompt, agree to provide secret access.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
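You may run the notebook outside Colab, for example with the two variables exported in your shell. In that case a quick check can fail fast with a readable message. The helper below is a hypothetical convenience function and is not part of Colab or kagglehub. It only reads os.environ.

```python
import os


def require_kaggle_credentials() -> tuple[str, str]:
  """Hypothetical helper: return (username, key), or raise if either is unset."""
  missing = [
      name for name in ("KAGGLE_USERNAME", "KAGGLE_KEY")
      if not os.environ.get(name)
  ]
  if missing:
    raise RuntimeError("Missing Kaggle credentials: " + ", ".join(missing))
  return os.environ["KAGGLE_USERNAME"], os.environ["KAGGLE_KEY"]
```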
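Install the recurrentgemma library
Free Colab hardware acceleration is currently insufficient to run this notebook. If you are using Colab Pay As You Go or Colab Pro, click Edit > Notebook settings > select A100 GPU > Save to enable hardware acceleration.
Next, install the Google DeepMind recurrentgemma library from github.com/google-deepmind/recurrentgemma. If you get an error about "pip's dependency resolver", you can usually ignore it.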
pip install -q git+https://github.com/google-deepmind/recurrentgemma.git
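Import libraries
This notebook uses Flax (for neural networks), core JAX, SentencePiece (for tokenization), Chex (a library of utilities for writing reliable JAX code), Optax (a gradient processing and optimization library), and TensorFlow Datasets.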
import pathlib
from typing import Any, Mapping, Iterator
import enum
import functools
import chex
import jax
import jax.numpy as jnp
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
import sentencepiece as spm
from recurrentgemma import jax as recurrentgemma
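Load the RecurrentGemma model
- Load the RecurrentGemma model with kagglehub.model_download, which takes three arguments:
  - handle: The model handle from Kaggle
  - path: (Optional string) The local path
  - force_download: (Optional boolean) Forces the model to be re-downloaded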
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub
RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download...
100%|██████████| 3.85G/3.85G [00:50<00:00, 81.5MB/s]
Extracting model files...
print('RECURRENTGEMMA_VARIANT:', RECURRENTGEMMA_VARIANT)
RECURRENTGEMMA_VARIANT: 2b-it
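- Check the location of the model weights and the tokenizer, then set the path variables. The tokenizer file is in the main directory where you downloaded the model, and the model weights are in a subdirectory. For example:
  - The tokenizer.model file will be in /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1.
  - The model checkpoint will be in /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it.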
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model
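The checkpoint directory and the tokenizer file sit at different levels of the download. A small sanity check can therefore catch a wrong variant name or an incomplete download early. The helper below is hypothetical. It assumes the layout described above: tokenizer.model in the download root, and the weights in a subdirectory named after the variant.

```python
import os


def resolve_model_paths(model_root: str, variant: str) -> tuple[str, str]:
  """Hypothetical helper: build and check the checkpoint and tokenizer paths."""
  ckpt_path = os.path.join(model_root, variant)
  tokenizer_path = os.path.join(model_root, 'tokenizer.model')
  if not os.path.isdir(ckpt_path):
    raise FileNotFoundError(f'Checkpoint directory not found: {ckpt_path}')
  if not os.path.isfile(tokenizer_path):
    raise FileNotFoundError(f'Tokenizer file not found: {tokenizer_path}')
  return ckpt_path, tokenizer_path
```

In the notebook, you could call resolve_model_paths(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT) instead of the two os.path.join calls.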
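Load and prepare the MTNT dataset and the Gemma tokenizer
You'll use the MTNT (Machine Translation of Noisy Text) dataset, which is available from TensorFlow Datasets.
Download the English-to-French portion of the MTNT dataset, then sample two examples. Each example in the dataset contains two entries: src, the original English sentence, and dst, the corresponding French translation.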
ds = tfds.load("mtnt/en-fr", split="train")
ds = ds.take(2)
ds = ds.as_numpy_iterator()
for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()
Downloading and preparing dataset 35.08 MiB (download: 35.08 MiB, generated: 11.33 MiB, total: 46.41 MiB) to /root/tensorflow_datasets/mtnt/en-fr/1.0.0...
Dataset mtnt downloaded and prepared to /root/tensorflow_datasets/mtnt/en-fr/1.0.0. Subsequent calls will reuse this data.
Example 0:
dst: b'Le groupe de " toutes les \xc3\xa9toiles potentielles de la conf\xc3\xa9rence de l\'Est mais qui ne s\'en sortent pas dans le groupe de l\'Ouest ".'
src: b'The group of \xe2\x80\x9ceastern conference potential all stars but not making it in the West\xe2\x80\x9d group.'

Example 1:
dst: b"Kameron est-elle un peu aigrie de son manque de temps \xc3\xa0 l'\xc3\xa9cran ?"
src: b'Is Kameron a Little Salty About Her Lack of Air Time?'
Load the Gemma tokenizer, constructed using sentencepiece.SentencePieceProcessor:
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
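Customize the SentencePieceProcessor for the English-to-French translation task. Because you'll be fine-tuning the RecurrentGemma (Griffin) model to translate English input, you need to make a few adjustments:
- Input prefix: Adding a common prefix to each input signals the translation task. For example, you could use a prompt with a prefix such as Translate this into French: [INPUT_SENTENCE].
- Translation start suffix: Adding a suffix at the end of each prompt tells the model exactly when to begin translating. A new line does the job.
- Language model tokens: The RecurrentGemma (Griffin) model expects a "beginning of sequence" token at the start of each sequence. Similarly, you need to add an "end of sequence" token at the end of each training example.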
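The three adjustments can be sketched in plain Python before wiring them into SentencePiece. The token IDs below (BOS = 2, EOS = 1) are assumptions read off the token arrays printed later in this notebook. The integer lists stand in for real SentencePiece output.

```python
TRANSLATION_PREFIX = 'Translate this into French:\n'
TRANSLATION_SUFFIX = '\n'
BOS_ID = 2  # Assumption: the leading ID in the printed token arrays.
EOS_ID = 1  # Assumption: the trailing ID of each printed target sequence.


def build_prompt(sentence: str) -> str:
  """Wrap an English sentence in the task prefix and the start-of-translation suffix."""
  return TRANSLATION_PREFIX + sentence + TRANSLATION_SUFFIX


def frame_ids(ids: list[int], add_eos: bool) -> list[int]:
  """Prepend BOS and optionally append EOS to already-encoded token IDs."""
  framed = [BOS_ID] + list(ids)
  if add_eos:
    framed.append(EOS_ID)
  return framed
```

Source prompts are framed without EOS so that the model keeps generating. French targets get EOS so that the model learns where a translation stops.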
Build a custom wrapper around the SentencePieceProcessor as follows:
class GriffinTokenizer:
  """A custom wrapper around a SentencePieceProcessor."""

  def __init__(self, spm_processor: spm.SentencePieceProcessor):
    self._spm_processor = spm_processor

  @property
  def pad_id(self) -> int:
    """Fast access to the pad ID."""
    return self._spm_processor.pad_id()

  def tokenize(
      self,
      example: str | bytes,
      prefix: str = '',
      suffix: str = '',
      add_eos: bool = True,
  ) -> jax.Array:
    """
    A tokenization function.

    Args:
      example: Input string to tokenize.
      prefix: Prefix to add to the input string.
      suffix: Suffix to add to the input string.
      add_eos: If True, add an end of sentence token at the end of the output
        sequence.

    Returns:
      Tokens corresponding to the input string.
    """
    int_list = [self._spm_processor.bos_id()]
    int_list.extend(self._spm_processor.EncodeAsIds(prefix + example + suffix))
    if add_eos:
      int_list.append(self._spm_processor.eos_id())

    return jnp.array(int_list, dtype=jnp.int32)

  def tokenize_tf_op(
      self,
      str_tensor: tf.Tensor,
      prefix: str = '',
      suffix: str = '',
      add_eos: bool = True,
  ) -> tf.Tensor:
    """A TensorFlow operator for the `tokenize` function."""
    encoded = tf.numpy_function(
        self.tokenize,
        [str_tensor, prefix, suffix, add_eos],
        tf.int32)
    encoded.set_shape([None])
    return encoded

  def to_string(self, tokens: jax.Array) -> str:
    """Convert an array of tokens to a string."""
    # Decode (not encode) the token IDs back into text.
    return self._spm_processor.DecodeIds(tokens.tolist())
You can try it out by instantiating the new custom GriffinTokenizer and applying it to a small sample of the MTNT dataset:
def tokenize_source(tokenizer, example: tf.Tensor):
  return tokenizer.tokenize_tf_op(
      example,
      prefix='Translate this into French:\n',
      suffix='\n',
      add_eos=False
  )


def tokenize_destination(tokenizer, example: tf.Tensor):
  return tokenizer.tokenize_tf_op(example, add_eos=True)
tokenizer = GriffinTokenizer(vocab)
ds = tfds.load("mtnt/en-fr",split="train")
ds = ds.take(2)
ds = ds.map(lambda x: {
'src': tokenize_source(tokenizer, x['src']),
'dst': tokenize_destination(tokenizer, x['dst'])
})
ds = ds.as_numpy_iterator()
for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()
Example 0:
src: [ 2 49688 736 1280 6987 235292 108 651 2778 576 1080 104745 11982 5736 832 8995 901 780 3547 665 575 573 4589 235369 2778 235265 108]
dst: [ 2 2025 29653 581 664 16298 1437 55563 41435 7840 581 683 111452 581 533 235303 9776 4108 2459 679 485 235303 479 6728 579 1806 2499 709 29653 581 533 235303 101323 16054 1]

Example 1:
src: [ 2 49688 736 1280 6987 235292 108 2437 87150 477 476 11709 230461 8045 3636 40268 576 4252 4897 235336 108]
dst: [ 2 213606 477 1455 235290 3510 748 8268 191017 2809 581 2032 69972 581 11495 1305 533 235303 65978 1654 1]
Build a data loader for the entire MTNT dataset:
@chex.dataclass(frozen=True)
class TrainingInput:
  # Input tokens provided to the model.
  input_tokens: jax.Array

  # A mask that determines which tokens contribute to the target loss
  # calculation.
  target_mask: jax.Array


class DatasetSplit(enum.Enum):
  TRAIN = 'train'
  VALIDATION = 'valid'


class MTNTDatasetBuilder:
  """A data loader for the MTNT dataset."""

  N_ITEMS = {DatasetSplit.TRAIN: 35_692, DatasetSplit.VALIDATION: 811}

  BUFFER_SIZE_SHUFFLE = 10_000
  TRANSLATION_PREFIX = 'Translate this into French:\n'
  TRANSLATION_SUFFIX = '\n'

  def __init__(self,
               tokenizer: GriffinTokenizer,
               max_seq_len: int):
    """A constructor.

    Args:
      tokenizer: The tokenizer to use.
      max_seq_len: The size of each sequence in a given batch.
    """
    self._tokenizer = tokenizer
    self._base_data = {
        DatasetSplit.TRAIN: tfds.load("mtnt/en-fr", split="train"),
        DatasetSplit.VALIDATION: tfds.load("mtnt/en-fr", split="valid"),
    }
    self._max_seq_len = max_seq_len

  def _tokenize_source(self, example: tf.Tensor):
    """A tokenization function for the source."""
    return self._tokenizer.tokenize_tf_op(
        example, prefix=self.TRANSLATION_PREFIX, suffix=self.TRANSLATION_SUFFIX,
        add_eos=False
    )

  def _tokenize_destination(self, example: tf.Tensor):
    """A tokenization function for the French translation."""
    return self._tokenizer.tokenize_tf_op(example, add_eos=True)

  def _pad_up_to_max_len(self,
                         input_tensor: tf.Tensor,
                         pad_value: int | bool,
                         ) -> tf.Tensor:
    """Pad the given tensor up to sequence length of a batch."""
    seq_len = tf.shape(input_tensor)[0]
    to_pad = tf.maximum(self._max_seq_len - seq_len, 0)
    return tf.pad(
        input_tensor, [[0, to_pad]], mode='CONSTANT', constant_values=pad_value,
    )

  def _to_training_input(
      self,
      src_tokens: jax.Array,
      dst_tokens: jax.Array,
  ) -> TrainingInput:
    """Build a training input from a tuple of source and destination tokens."""

    # The input sequence fed to the model is simply the concatenation of the
    # source and the destination.
    tokens = tf.concat([src_tokens, dst_tokens], axis=0)

    # You want to prevent the model from updating based on the source (input)
    # tokens. To achieve this, add a target mask to each input.
    q_mask = tf.zeros_like(src_tokens, dtype=tf.bool)
    a_mask = tf.ones_like(dst_tokens, dtype=tf.bool)
    mask = tf.concat([q_mask, a_mask], axis=0)

    # If the output tokens sequence is smaller than the target sequence size,
    # then pad it with pad tokens.
    tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)

    # You don't want to perform the backward on the pad tokens.
    mask = self._pad_up_to_max_len(mask, False)

    return TrainingInput(input_tokens=tokens, target_mask=mask)

  def get_train_dataset(self, batch_size: int, num_epochs: int):
    """Build the training dataset."""

    # Tokenize each sample.
    ds = self._base_data[DatasetSplit.TRAIN].map(
        lambda x: (self._tokenize_source(x['src']),
                   self._tokenize_destination(x['dst']))
    )

    # Convert them to training inputs.
    ds = ds.map(lambda x, y: self._to_training_input(x, y))

    # Remove the samples which are too long.
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)

    # Shuffle the dataset.
    ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)

    # Repeat if necessary.
    ds = ds.repeat(num_epochs)

    # Build batches.
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

  def get_validation_dataset(self, batch_size: int):
    """Build the validation dataset."""

    # Same as the training dataset, but no shuffling and no repetition
    ds = self._base_data[DatasetSplit.VALIDATION].map(
        lambda x: (self._tokenize_source(x['src']),
                   self._tokenize_destination(x['dst']))
    )
    ds = ds.map(lambda x, y: self._to_training_input(x, y))
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds
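The core of _to_training_input does three things: it concatenates the two sequences, builds a loss mask, and pads on the right. The sketch below reproduces that logic with jax.numpy on plain arrays, so you can check it without TensorFlow. The function name is illustrative and the toy token IDs are made up. The pad ID of 0 is an assumption taken from the padded batches printed further down.

```python
import jax.numpy as jnp


def sketch_to_training_input(src, dst, max_seq_len, pad_id=0):
  """Mirror of MTNTDatasetBuilder._to_training_input using jax.numpy.

  Returns (tokens, mask). The mask is True only on destination (French) tokens.
  """
  tokens = jnp.concatenate([src, dst])
  mask = jnp.concatenate(
      [jnp.zeros_like(src, dtype=bool), jnp.ones_like(dst, dtype=bool)])
  # Pad on the right up to max_seq_len. Longer sequences are left untouched
  # here; the dataset filter drops them afterwards.
  to_pad = max(max_seq_len - tokens.shape[0], 0)
  tokens = jnp.pad(tokens, (0, to_pad), constant_values=pad_id)
  mask = jnp.pad(mask, (0, to_pad), constant_values=False)
  return tokens, mask
```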
Try out the MTNTDatasetBuilder by using it with the custom GriffinTokenizer created earlier and sampling two batches of three examples each from the MTNT dataset:
dataset_builder = MTNTDatasetBuilder(tokenizer, max_seq_len=20)
ds = dataset_builder.get_train_dataset(3, 1)
ds = ds.take(2)
ds = ds.as_numpy_iterator()
for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'>
Example 0:
input_tokens: [[ 2 49688 736 1280 6987 235292 108 12583 665 235265 108 2 6151 94975 1320 6238 235265 1 0 0]
 [ 2 49688 736 1280 6987 235292 108 4899 29960 11270 108282 235265 108 2 4899 79025 11270 108282 1 0]
 [ 2 49688 736 1280 6987 235292 108 26620 235265 108 2 26620 235265 1 0 0 0 0 0 0]]
target_mask: [[False False False False False False False False False False False True True True True True True True False False]
 [False False False False False False False False False False False False False True True True True True True False]
 [False False False False False False False False False False True True True True False False False False False False]]

Example 1:
input_tokens: [[ 2 49688 736 1280 6987 235292 108 527 5174 1683 235336 108 2 206790 581 20726 482 2208 1654 1]
 [ 2 49688 736 1280 6987 235292 108 28484 235256 235336 108 2 120500 13832 1654 1 0 0 0 0]
 [ 2 49688 736 1280 6987 235292 108 235324 235304 2705 235265 108 2 235324 235304 19963 235265 1 0 0]]
target_mask: [[False False False False False False False False False False False False True True True True True True True True]
 [False False False False False False False False False False False True True True True True False False False False]
 [False False False False False False False False False False False False True True True True True True False False]]
Configure the model
Before you start fine-tuning the RecurrentGemma model, you need to configure it.
Load the RecurrentGemma (Griffin) model checkpoint with the recurrentgemma.jax.utils.load_parameters method:
params = recurrentgemma.load_parameters(CKPT_PATH, "single_device")
To automatically load the correct configuration from the RecurrentGemma model checkpoint, use recurrentgemma.GriffinConfig.from_flax_params_or_variables:
config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(params)
Instantiate the Griffin model with recurrentgemma.jax.Griffin:
model = recurrentgemma.Griffin(config)
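Create a sampler with recurrentgemma.jax.Sampler on top of the RecurrentGemma model checkpoint/weights and the tokenizer. You'll use it to check whether your model can perform translation: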
sampler = recurrentgemma.Sampler(model=model, vocab=vocab, params=params)
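Fine-tune the model
In this section, you'll:
- Use the recurrentgemma.jax.griffin.Griffin class to create the forward pass and loss function.
- Build the position vectors for the tokens.
- Build the training step function.
- Build the validation step without the backward pass.
- Create the training loop.
- Fine-tune the RecurrentGemma model.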
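Define the forward pass and the loss function using the recurrentgemma.jax.griffin.Griffin class. The RecurrentGemma Griffin class inherits from flax.linen.Module and offers two essential methods:
- init: Initializes the model's parameters.
- apply: Executes the model's __call__ function using a given set of parameters.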
Because you're working with pre-trained RecurrentGemma weights, you don't need to use the init function.
def forward_and_loss_fn(
    params,
    *,
    model: recurrentgemma.Griffin,
    input_tokens: jax.Array,  # Shape [B, L]
    input_mask: jax.Array,  # Shape [B, L]
    positions: jax.Array,  # Shape [B, L]
) -> jax.Array:
  """Forward pass and loss function.

  Args:
    params: model's input parameters.
    model: Griffin model to call.
    input_tokens: input tokens sequence, shape [B, L].
    input_mask: True for tokens that contribute to the loss, False for tokens
      to ignore, shape [B, L].
    positions: relative position of each token, shape [B, L].

  Returns:
    Softmax cross-entropy loss for the next-token prediction task.
  """
  batch_size = input_tokens.shape[0]
  # Forward pass on the input data.
  # No attention cache is needed here.
  # Exclude the last step as it does not appear in the targets.
  logits, _ = model.apply(
      {"params": params},
      tokens=input_tokens[:, :-1],
      segment_pos=positions[:, :-1],
      cache=None,
  )

  # Similarly, the first token cannot be predicted.
  target_tokens = input_tokens[:, 1:]
  target_mask = input_mask[:, 1:]

  # Convert the target labels into one-hot encoded vectors.
  one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])

  # Don't update on unwanted tokens.
  one_hot = one_hot * target_mask.astype(one_hot.dtype)[..., None]

  # Normalization factor.
  norm_factor = batch_size * (jnp.sum(target_mask) + 1e-8)

  # Return the negative log-likelihood loss (NLL) function.
  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) / norm_factor
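You can check the masking in forward_and_loss_fn on toy numbers. The function below isolates the loss computation (the name is illustrative). With all-zero logits over a vocabulary of 4, every unmasked target has probability 1/4. For a batch of one, the loss should therefore be log 4, whatever happens at masked positions. As in the tutorial code, the normalization multiplies the count of unmasked tokens by the batch size.

```python
import jax
import jax.numpy as jnp


def masked_nll(logits, target_tokens, target_mask):
  """Same computation as the tail of forward_and_loss_fn.

  logits: [B, L, V]; target_tokens: [B, L] ints; target_mask: [B, L] bools.
  """
  batch_size = logits.shape[0]
  one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])
  # Zero out the one-hot rows of masked positions so they add nothing.
  one_hot = one_hot * target_mask.astype(one_hot.dtype)[..., None]
  norm_factor = batch_size * (jnp.sum(target_mask) + 1e-8)
  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) / norm_factor


toy_logits = jnp.zeros((1, 3, 4))
toy_targets = jnp.array([[1, 2, 3]])
toy_mask = jnp.array([[True, False, True]])
toy_loss = masked_nll(toy_logits, toy_targets, toy_mask)

# Changing the logits at the masked-out position must not change the loss.
noisy_logits = toy_logits.at[0, 1].set(jnp.array([5.0, -3.0, 2.0, 0.0]))
toy_loss_noisy = masked_nll(noisy_logits, toy_targets, toy_mask)
```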
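Build the train_step function, which performs the backward pass and updates the model's parameters accordingly:
- jax.value_and_grad evaluates the loss function and its gradients during the forward and backward passes.
- optax.apply_updates applies the optimizer's updates to the parameters.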
Params = Mapping[str, Any]


def get_positions(example: jax.Array, pad_id: int) -> jax.Array:
  """Builds the position vector from the given tokens."""
  pad_mask = example != pad_id
  positions = jnp.cumsum(pad_mask, axis=-1)
  # Subtract one for all positions from the first valid one as they are
  # 0-indexed
  positions = positions - (positions >= 1)
  return positions


@functools.partial(
    jax.jit,
    static_argnames=['model', 'optimizer'],
    donate_argnames=['params', 'opt_state'],
)
def train_step(
    model: recurrentgemma.Griffin,
    params: Params,
    optimizer: optax.GradientTransformation,
    opt_state: optax.OptState,
    pad_id: int,
    example: TrainingInput,
) -> tuple[jax.Array, Params, optax.OptState]:
  """The train step.

  Args:
    model: The RecurrentGemma (Griffin) model.
    params: The model's input parameters.
    optimizer: The Optax optimizer to use.
    opt_state: The input optimizer's state.
    pad_id: The ID of the pad token.
    example: The input batch.

  Returns:
    Training loss, updated parameters, updated optimizer state.
  """

  positions = get_positions(example.input_tokens, pad_id)

  # Forward and backward passes.
  train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(
      params,
      model=model,
      input_tokens=example.input_tokens,
      input_mask=example.target_mask,
      positions=positions,
  )
  # Update the parameters.
  updates, opt_state = optimizer.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)

  return train_loss, params, opt_state
Build the validation_step function without the backward pass:
@functools.partial(jax.jit, static_argnames=['model'])
def validation_step(
    model: recurrentgemma.Griffin,
    params: Params,
    pad_id: int,
    example: TrainingInput,
) -> jax.Array:
  return forward_and_loss_fn(
      params,
      model=model,
      input_tokens=example.input_tokens,
      input_mask=example.target_mask,
      positions=get_positions(example.input_tokens, pad_id),
  )
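Both train_step and validation_step derive positions with get_positions. The check below uses a toy right-padded row. The pad ID of 0 is an assumption matching the padded batches printed earlier, and the other IDs are arbitrary. Real tokens get 0-based positions, and trailing pads repeat the last position. That repetition is harmless because the pads are masked out of the loss. The function is a renamed copy of get_positions so that it runs on its own.

```python
import jax
import jax.numpy as jnp


def sketch_get_positions(example: jax.Array, pad_id: int) -> jax.Array:
  """Same logic as get_positions in the tutorial."""
  pad_mask = example != pad_id
  positions = jnp.cumsum(pad_mask, axis=-1)
  # Shift to 0-based positions once the first real token has been seen.
  positions = positions - (positions >= 1)
  return positions


toy_tokens = jnp.array([[2, 49688, 736, 0, 0]])
toy_positions = sketch_get_positions(toy_tokens, pad_id=0)
```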
Define the training loop:
def train_loop(
    model: recurrentgemma.Griffin,
    params: Params,
    optimizer: optax.GradientTransformation,
    train_ds: Iterator[TrainingInput],
    validation_ds: tf.data.Dataset,
    num_steps: int | None = None,
    eval_every_n: int = 20,
):
  # Note: the pad ID is read from the global `dataset_builder`, which must be
  # defined before this function is called.
  opt_state = jax.jit(optimizer.init)(params)

  step_counter = 0
  avg_loss = 0

  # The first round of the validation loss.
  n_steps_eval = 0
  eval_loss = 0
  for val_example in validation_ds.as_numpy_iterator():
    eval_loss += validation_step(
        model, params, dataset_builder._tokenizer.pad_id, val_example
    )
    n_steps_eval += 1
  print(f"Start, validation loss: {eval_loss/n_steps_eval}")

  for train_example in train_ds:
    train_loss, params, opt_state = train_step(
        model=model,
        params=params,
        optimizer=optimizer,
        opt_state=opt_state,
        pad_id=dataset_builder._tokenizer.pad_id,
        example=train_example,
    )

    step_counter += 1
    avg_loss += train_loss
    if step_counter % eval_every_n == 0:
      eval_loss = 0

      n_steps_eval = 0
      val_iterator = validation_ds.as_numpy_iterator()
      for val_example in val_iterator:
        eval_loss += validation_step(
            model,
            params,
            dataset_builder._tokenizer.pad_id,
            val_example,
        )
        n_steps_eval += 1
      avg_loss /= eval_every_n
      eval_loss /= n_steps_eval
      print(f"STEP {step_counter} training loss: {avg_loss} - eval loss: {eval_loss}")
      avg_loss = 0
    # Stop once num_steps training steps have run.
    if num_steps is not None and step_counter >= num_steps:
      break
  return params
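You must choose an optimizer (Optax) here. On devices with less memory, use SGD, because its memory footprint is much smaller. For the best fine-tuning performance, try AdamW. For each optimizer, this example provides the hyperparameters that work best for this notebook's task with the 2b-it checkpoint.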
def griffin_weight_decay_mask(params_like: optax.Params) -> Any:
  # Don't put weight decay on the RGLRU, the embeddings and any biases
  def enable_weight_decay(path: list[Any], _: Any) -> bool:
    # Parameters in the LRU and embedder
    path = [dict_key.key for dict_key in path]
    if 'rg_lru' in path or 'embedder' in path:
      return False
    # All biases and scales
    if path[-1] in ('b', 'scale'):
      return False
    return True

  return jax.tree_util.tree_map_with_path(enable_weight_decay, params_like)


optimizer_choice = "sgd"

if optimizer_choice == "sgd":
  optimizer = optax.sgd(learning_rate=1e-3)
  num_steps = 300
elif optimizer_choice == "adamw":
  optimizer = optax.adamw(
      learning_rate=1e-4,
      b2=0.96,
      eps=1e-8,
      weight_decay=0.1,
      mask=griffin_weight_decay_mask,
  )
  num_steps = 100
else:
  raise ValueError(f"Unknown optimizer: {optimizer_choice}")
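To see which parameters AdamW will decay, you can apply the mask logic to a toy pytree. The nested names below (blocks.0, mlp_block, ffw_up, and so on) are hypothetical stand-ins. Only the rg_lru, embedder, b, and scale keys matter to the mask. The function is a renamed copy of griffin_weight_decay_mask so that the snippet runs on its own.

```python
from typing import Any

import jax
import jax.numpy as jnp


def weight_decay_mask_sketch(params_like: Any) -> Any:
  """Copy of griffin_weight_decay_mask: True means "apply weight decay"."""
  def enable_weight_decay(path, _) -> bool:
    keys = [dict_key.key for dict_key in path]
    # No decay for the RG-LRU and the embeddings.
    if 'rg_lru' in keys or 'embedder' in keys:
      return False
    # No decay for biases and scales.
    if keys[-1] in ('b', 'scale'):
      return False
    return True

  return jax.tree_util.tree_map_with_path(enable_weight_decay, params_like)


toy_tree = {
    'embedder': {'input_embedding': jnp.zeros((8, 2))},
    'blocks.0': {
        'rg_lru': {'a_param': jnp.zeros((2,))},
        'mlp_block': {'ffw_up': {'w': jnp.zeros((2, 2)), 'b': jnp.zeros((2,))}},
    },
    'final_norm': {'scale': jnp.zeros((2,))},
}
decay_mask = weight_decay_mask_sketch(toy_tree)
```

Through its mask argument, optax.adamw accepts either a pytree of booleans like decay_mask or a callable that returns one. The tutorial uses the callable form by passing griffin_weight_decay_mask.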
Prepare the training and validation datasets:
# Choose a small sequence length size, so that everything fits in memory.
num_epochs = 1
batch_size = 1
sequence_length = 32
# Make the dataset builder.
tokenizer = GriffinTokenizer(vocab)
dataset_builder= MTNTDatasetBuilder(tokenizer, sequence_length + 1)
# Build the training dataset.
train_ds = dataset_builder.get_train_dataset(
batch_size=batch_size,
num_epochs=num_epochs,
).as_numpy_iterator()
# Build the validation dataset, with a limited number of samples for this demo.
validation_ds = dataset_builder.get_validation_dataset(
batch_size=batch_size,
).take(50)
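The dataset builder is created with sequence_length + 1 tokens per example. Presumably this is because forward_and_loss_fn drops one token at each end when it shifts inputs against targets. The shape check below uses toy sizes and zero-filled tokens. It shows that a window of sequence_length + 1 tokens yields exactly sequence_length model inputs and sequence_length targets.

```python
import jax.numpy as jnp

toy_seq_len = 32
toy_batch = jnp.zeros((1, toy_seq_len + 1), dtype=jnp.int32)

model_inputs = toy_batch[:, :-1]  # What the model sees.
next_token_targets = toy_batch[:, 1:]  # What it must predict, one step ahead.
```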
Begin fine-tuning the RecurrentGemma (Griffin) model with a limited number of steps (num_steps):
trained_params = train_loop(
model=model,
params=params,
optimizer=optimizer,
train_ds=train_ds,
validation_ds=validation_ds,
num_steps=num_steps,
)
Start, validation loss: 7.894117832183838
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,33]), ShapedArray(bool[1,33]), ShapedArray(int32[], weak_type=True). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
STEP 20 training loss: 4.592616081237793 - eval loss: 2.847407102584839
STEP 40 training loss: 2.7537424564361572 - eval loss: 2.9258534908294678
STEP 60 training loss: 2.835618257522583 - eval loss: 2.4382340908050537
STEP 80 training loss: 2.6322107315063477 - eval loss: 2.3696839809417725
STEP 100 training loss: 1.8703256845474243 - eval loss: 2.355681896209717
STEP 120 training loss: 2.7280433177948 - eval loss: 2.4059958457946777
STEP 140 training loss: 2.3047447204589844 - eval loss: 2.083082914352417
STEP 160 training loss: 2.3432137966156006 - eval loss: 2.095074415206909
STEP 180 training loss: 2.1081202030181885 - eval loss: 2.006460189819336
STEP 200 training loss: 2.5359647274017334 - eval loss: 1.9667452573776245
STEP 220 training loss: 2.202195644378662 - eval loss: 1.9440618753433228
STEP 240 training loss: 2.756615400314331 - eval loss: 2.1073737144470215
STEP 260 training loss: 2.5128934383392334 - eval loss: 2.117241859436035
STEP 280 training loss: 2.73045015335083 - eval loss: 1.9159646034240723
STEP 300 training loss: 2.0918595790863037 - eval loss: 1.9742532968521118
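Both the training loss and the validation loss should trend downward as the number of steps increases, although individual evaluations can fluctuate.
To make sure your input matches the training format, use the prefix Translate this into French:\n and end the input with a newline character. This signals the model to begin the translation.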
sampler.params = trained_params
output = sampler(
["Translate this into French:\nHello, my name is Morgane.\n"],
total_generation_steps=100,
)
print(output.text[0])
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,16]). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
Mais je m'appelle Morgane.
Learn more
- You can learn more about the Google DeepMind recurrentgemma library on GitHub. The repository contains docstrings for the methods and modules used in this tutorial, such as recurrentgemma.jax.load_parameters, recurrentgemma.jax.Griffin, and recurrentgemma.jax.Sampler.
- The following libraries have their own documentation sites: core JAX, Flax, Chex, Optax, and Orbax.
- For sentencepiece tokenizer/detokenizer documentation, see Google's sentencepiece GitHub repository.
- For kagglehub documentation, see README.md in Kaggle's kagglehub GitHub repository.
- Learn how to use Gemma models with Google Cloud Vertex AI.
- If you are using Google Cloud TPUs (v3-8 and newer), also update to the latest jax[tpu] package (!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html), restart the runtime, and check that the jax and jaxlib versions match (!pip list | grep jax). This prevents the RuntimeError that a jaxlib and jax version mismatch can cause. For more JAX installation instructions, refer to the JAX documentation.
- Read the RecurrentGemma: Moving Past Transformers for Efficient Open Language Models paper by Google DeepMind.
- Read the Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models paper by Google DeepMind to learn more about the model architecture used by RecurrentGemma.